/*
 * Copyright (c) 2023-2024 elsfs Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.elsfs.cloud.screw.query.postgresql;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.elsfs.cloud.common.util.exception.QueryException;
import org.elsfs.cloud.common.util.lang.Assert;
import org.elsfs.cloud.common.util.lang.CollectionUtils;
import org.elsfs.cloud.common.util.lang.ExceptionUtils;
import org.elsfs.cloud.common.util.sql.JdbcUtils;
import org.elsfs.cloud.screw.constant.ScrewConstants;
import org.elsfs.cloud.screw.mapping.Mapping;
import org.elsfs.cloud.screw.metadata.Column;
import org.elsfs.cloud.screw.metadata.Database;
import org.elsfs.cloud.screw.metadata.PrimaryKey;
import org.elsfs.cloud.screw.query.AbstractDatabaseQuery;
import org.elsfs.cloud.screw.query.postgresql.model.PostgreSqlColumnModel;
import org.elsfs.cloud.screw.query.postgresql.model.PostgreSqlDatabaseModel;
import org.elsfs.cloud.screw.query.postgresql.model.PostgreSqlPrimaryKeyModel;
import org.elsfs.cloud.screw.query.postgresql.model.PostgreSqlTableModel;

/**
 * PostgreSQL implementation of {@link AbstractDatabaseQuery}.
 *
 * <p>Tables, columns and per-table primary keys are read through the JDBC metadata returned by
 * {@code getMetaData()}. Column types/lengths and the bulk primary-key listing are read with custom
 * SQL against {@code information_schema} and {@code pg_catalog}.
 *
 * @author zeng
 */
public class PostgreSqlDataBaseQuery extends AbstractDatabaseQuery {

  /**
   * Fetch-size hint applied to the statement that loads column types for every table. The value is
   * carried over unchanged from the previous implementation.
   */
  private static final int ALL_COLUMNS_FETCH_SIZE = 4284;

  /** Schema used by {@link #getPrimaryKeys()} when {@code getSchema()} yields {@code null}. */
  private static final String DEFAULT_SCHEMA = "public";

  /**
   * Start of the column-type query, up to and including the schema/catalog filter.
   *
   * <p>The schema and catalog are bound as parameters 1 and 2. {@code COLUMN_TYPE} is
   * {@code udt_name}, followed by {@code (length)} when a character length or numeric precision
   * exists.
   */
  private static final String COLUMN_TYPE_SQL_PREFIX =
      "SELECT \"TABLE_NAME\", \"TABLE_SCHEMA\", \"COLUMN_NAME\", \"LENGTH\","
          + " concat(\"UDT_NAME\", case when \"LENGTH\" isnull then '' else concat('(',"
          + " concat(\"LENGTH\", ')')) end) \"COLUMN_TYPE\" FROM(select table_schema as"
          + " \"TABLE_SCHEMA\", column_name as \"COLUMN_NAME\", table_name as"
          + " \"TABLE_NAME\", udt_name as \"UDT_NAME\", case when"
          + " coalesce(character_maximum_length, numeric_precision, -1) = -1 then null else"
          + " coalesce(character_maximum_length, numeric_precision, -1) end as \"LENGTH\""
          + " from information_schema.columns a where table_schema = ? and table_catalog = ?";

  /** Optional single-table filter (parameter 3) appended inside the column-type subquery. */
  private static final String COLUMN_TYPE_TABLE_FILTER = " and table_name = ?";

  /** Closes the column-type subquery. */
  private static final String COLUMN_TYPE_SQL_SUFFIX = ") t";

  /**
   * Lists all primary-key columns of one schema in a single round trip. The schema is bound as
   * parameter 1. Querying table by table was slow, which is why this custom SQL exists.
   */
  private static final String PRIMARY_KEY_SQL =
      "SELECT result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME, result.COLUMN_NAME,"
          + " result.KEY_SEQ, result.PK_NAME FROM(SELECT NULL AS TABLE_CAT, n.nspname AS"
          + " TABLE_SCHEM, ct.relname AS TABLE_NAME, a.attname AS COLUMN_NAME,"
          + " (information_schema._pg_expandarray(i.indkey)).n AS KEY_SEQ, ci.relname AS"
          + " PK_NAME, information_schema._pg_expandarray(i.indkey) AS KEYS, a.attnum AS"
          + " A_ATTNUM FROM pg_catalog.pg_class ct JOIN pg_catalog.pg_attribute a ON (ct.oid ="
          + " a.attrelid) JOIN pg_catalog.pg_namespace n ON (ct.relnamespace = n.oid) JOIN"
          + " pg_catalog.pg_index i ON (a.attrelid = i.indrelid) JOIN pg_catalog.pg_class ci ON"
          + " (ci.oid = i.indexrelid) WHERE true AND n.nspname = ? AND i.indisprimary)"
          + " result where result.A_ATTNUM = (result.KEYS).x ORDER BY result.table_name,"
          + " result.pk_name, result.key_seq";

  /**
   * Creates a query bound to the given data source.
   *
   * @param dataSource {@link DataSource} used to obtain connections
   */
  public PostgreSqlDataBaseQuery(DataSource dataSource) {
    super(dataSource);
  }

  /**
   * Returns the database description, whose name is the current catalog.
   *
   * @return {@link Database} database information
   * @throws QueryException if the catalog cannot be read
   */
  @Override
  public Database getDataBase() throws QueryException {
    PostgreSqlDatabaseModel model = new PostgreSqlDatabaseModel();
    // Name of the current database.
    model.setDatabase(getCatalog());
    return model;
  }

  /**
   * Returns every table of type {@code TABLE} in the current catalog/schema.
   *
   * @return {@link List} of table information
   * @throws QueryException if the metadata query fails
   */
  @Override
  public List<PostgreSqlTableModel> getTables() throws QueryException {
    ResultSet resultSet = null;
    try {
      resultSet =
          getMetaData()
              .getTables(
                  getCatalog(), getSchema(), ScrewConstants.PERCENT_SIGN, new String[] {"TABLE"});
      return Mapping.convertList(resultSet, PostgreSqlTableModel.class);
    } catch (SQLException e) {
      throw ExceptionUtils.mpe(e);
    } finally {
      JdbcUtils.close(resultSet);
    }
  }

  /**
   * Returns the columns of {@code table}, enriched with type and length.
   *
   * <p>Columns are listed through JDBC metadata. Type and length come from
   * {@code information_schema.columns` and are cached per table in {@code columnsCaching}. The
   * cache is (re)loaded whenever any of the returned tables is missing from it.
   *
   * @param table {@link String} table name, or {@link ScrewConstants#PERCENT_SIGN} for all tables
   * @return {@link List} of column information
   * @throws QueryException if {@code table} is empty or a query fails
   */
  @Override
  public List<PostgreSqlColumnModel> getTableColumns(String table) throws QueryException {
    Assert.notEmpty(table, "Table name can not be empty!");
    try {
      List<PostgreSqlColumnModel> columns;
      try (ResultSet resultSet =
          getMetaData()
              .getColumns(getCatalog(), getSchema(), table, ScrewConstants.PERCENT_SIGN)) {
        columns = Mapping.convertList(resultSet, PostgreSqlColumnModel.class);
      }
      // "table" may be the "%" wildcard, so collect the distinct table names actually returned.
      List<String> tableNames =
          columns.stream().map(PostgreSqlColumnModel::getTableName).distinct().toList();
      // Reload when any requested table is uncached. Checking only "cache is empty" left tables
      // requested after the first call uncached, which caused an NPE when types were applied.
      if (tableNames.stream().anyMatch(name -> !columnsCaching.containsKey(name))) {
        cacheColumnTypes(table, tableNames);
      }
      columns.forEach(this::applyCachedType);
      return columns;
    } catch (SQLException e) {
      throw ExceptionUtils.mpe(e);
    }
  }

  /**
   * Loads column type/length rows from {@code information_schema.columns} for the current
   * schema/catalog.
   *
   * <p>For each name in {@code tableNames}, the rows belonging to that table are stored in
   * {@code columnsCaching}.
   *
   * @param table requested table name, or {@link ScrewConstants#PERCENT_SIGN} for all tables
   * @param tableNames distinct table names whose cache entries are written
   */
  private void cacheColumnTypes(String table, List<String> tableNames)
      throws SQLException, QueryException {
    boolean allTables = ScrewConstants.PERCENT_SIGN.equals(table);
    String sql =
        COLUMN_TYPE_SQL_PREFIX
            + (allTables ? "" : COLUMN_TYPE_TABLE_FILTER)
            + COLUMN_TYPE_SQL_SUFFIX;
    try (PreparedStatement statement = prepareStatement(sql)) {
      // Bind values rather than formatting them into the SQL; the table name comes from callers.
      statement.setString(1, getSchema());
      statement.setString(2, getCatalog());
      if (allTables) {
        if (statement.getFetchSize() < ALL_COLUMNS_FETCH_SIZE) {
          statement.setFetchSize(ALL_COLUMNS_FETCH_SIZE);
        }
      } else {
        statement.setString(3, table);
      }
      try (ResultSet resultSet = statement.executeQuery()) {
        List<PostgreSqlColumnModel> rows =
            Mapping.convertList(resultSet, PostgreSqlColumnModel.class);
        // Key: table name; value: that table's rows.
        for (String name : tableNames) {
          columnsCaching.put(
              name,
              rows.stream()
                  .filter(row -> name.equals(row.getTableName()))
                  .collect(Collectors.toList()));
        }
      }
    }
  }

  /**
   * Copies column length and type from the cached row that matches {@code column} by table and
   * column name. Does nothing when the column's table has no cache entry.
   *
   * @param column column to enrich in place
   */
  private void applyCachedType(PostgreSqlColumnModel column) {
    List<Column> cached = columnsCaching.get(column.getTableName());
    if (cached == null) {
      return;
    }
    for (Column candidate : cached) {
      if (column.getColumnName().equals(candidate.getColumnName())
          && column.getTableName().equals(candidate.getTableName())) {
        column.setColumnLength(candidate.getColumnLength());
        column.setColumnType(candidate.getColumnType());
      }
    }
  }

  /**
   * Returns the columns of every table in the current schema.
   *
   * @return {@link List} of column information
   * @throws QueryException QueryException
   */
  @Override
  public List<? extends Column> getTableColumns() throws QueryException {
    return getTableColumns(ScrewConstants.PERCENT_SIGN);
  }

  /**
   * Returns the primary-key columns of one table via JDBC metadata.
   *
   * @param table {@link String} table name
   * @return {@link List} of primary-key information
   * @throws QueryException QueryException
   */
  @Override
  public List<? extends PrimaryKey> getPrimaryKeys(String table) throws QueryException {
    ResultSet resultSet = null;
    try {
      resultSet = getMetaData().getPrimaryKeys(getCatalog(), getSchema(), table);
      return Mapping.convertList(resultSet, PostgreSqlPrimaryKeyModel.class);
    } catch (SQLException e) {
      throw ExceptionUtils.mpe(e);
    } finally {
      // NOTE(review): unlike the other methods, this also closes the shared connection. Confirm
      // that AbstractDatabaseQuery reopens it on next use, or drop the connection close.
      JdbcUtils.close(resultSet, this.connection);
    }
  }

  /**
   * Returns the primary-key columns of every table in the current schema with one query.
   *
   * <p>Falls back to {@code "public"} when {@code getSchema()} returns {@code null}.
   *
   * @return {@link List} of primary-key information ordered by table, key name and key sequence
   * @throws QueryException QueryException
   */
  @Override
  public List<? extends PrimaryKey> getPrimaryKeys() throws QueryException {
    try (PreparedStatement statement = prepareStatement(PRIMARY_KEY_SQL)) {
      // Use the configured schema; this query was previously hard-coded to 'public'.
      String schema = getSchema();
      statement.setString(1, schema != null ? schema : DEFAULT_SCHEMA);
      try (ResultSet resultSet = statement.executeQuery()) {
        return Mapping.convertList(resultSet, PostgreSqlPrimaryKeyModel.class);
      }
    } catch (SQLException e) {
      throw new QueryException(e);
    }
  }
}
